{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "exFeYM4KWlz9"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Oj6X6JHoWtVs"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5DZ2c-xfa9m"
      },
      "source": [
        "# フェデレーテッドラーニングリサーチの TFF: モデルと更新圧縮\n",
        "\n",
        "**注意**: この Colab は <a>最新リリースバージョン</a>の <code>tensorflow_federated</code> pip パッケージでの動作が確認されていますが、Tensorflow Federated プロジェクトは現在もプレリリース開発の段階にあるため、`master` では動作しない可能性があります。\n",
        "\n",
        "このチュートリアルでは、[EMNIST](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/emnist) データセットを使用しながら、`tff.learning.build_federated_averaging_process` API と [tensor_encoding](http://jakubkonecny.com/files/tensor_encoding.pdf) API を使用するフェデレーテッドアベレージングアルゴリズムにおける通信コストを削減するために非可逆圧縮アルゴリズムを有効化する方法を実演します。フェデレーテッドアベレージングアルゴリズムの詳細については、論文「[Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629)」をご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qrPTFv7ngz-P"
      },
      "source": [
        "## 始める前に\n",
        "\n",
        "始める前に、次のコードを実行し、環境が正しくセットアップされていることを確認してください。挨拶文が表示されない場合は、[インストール](../install.md)ガイドで手順を確認してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X_JnSqDxlw5T"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow_federated\n",
        "!pip install --quiet --upgrade tensorflow-model-optimization\n",
        "\n",
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ctxIBpYIl846"
      },
      "outputs": [],
      "source": [
        "import functools\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "from tensorflow_model_optimization.python.core.internal import tensor_encoding as te"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wj-O1cnxKHMw"
      },
      "source": [
        "TFF が動作していることを確認します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8VPepVmfdhHv"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "b'Hello, World!'"
            ]
          },
          "execution_count": 4,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "@tff.federated_computation\n",
        "def hello_world():\n",
        "  return 'Hello, World!'\n",
        "\n",
        "hello_world()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "30Pln72ihL-z"
      },
      "source": [
        "## 入力データを準備する\n",
        "\n",
        "このセクションでは、TFF に含まれる EMNIST データセットを読み込んで事前処理します。EMNIST データセットの詳細は、[画像分類のフェデレーテッドラーニング](https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification#preparing_the_input_data)チュートリアルをご覧ください。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oTP2Dndbl2Oe"
      },
      "outputs": [],
      "source": [
        "# This value only applies to EMNIST dataset, consider choosing appropriate\n",
        "# values if switching to other datasets.\n",
        "MAX_CLIENT_DATASET_SIZE = 418\n",
        "\n",
        "CLIENT_EPOCHS_PER_ROUND = 1\n",
        "CLIENT_BATCH_SIZE = 20\n",
        "TEST_BATCH_SIZE = 500\n",
        "\n",
        "emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(\n",
        "    only_digits=True)\n",
        "\n",
        "def reshape_emnist_element(element):\n",
        "  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])\n",
        "\n",
        "def preprocess_train_dataset(dataset):\n",
        "  \"\"\"Preprocessing function for the EMNIST training dataset.\"\"\"\n",
        "  return (dataset\n",
        "          # Shuffle according to the largest client dataset\n",
        "          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)\n",
        "          # Repeat to do multiple local epochs\n",
        "          .repeat(CLIENT_EPOCHS_PER_ROUND)\n",
        "          # Batch to a fixed client batch size\n",
        "          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)\n",
        "          # Preprocessing step\n",
        "          .map(reshape_emnist_element))\n",
        "\n",
        "emnist_train = emnist_train.preprocess(preprocess_train_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XUQA55yjhTGh"
      },
      "source": [
        "## モデルを定義する\n",
        "\n",
        "ここでは、元の FedAvg CNN に基づいて Keras モデルを定義し、それを [tff.learning.Model](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model) インスタンスにラッピングして TFF が消費できるようにします。\n",
        "\n",
        "モデルのみを直接生成する代わりに、モデルを生成する**関数**が必要となることに注意してください。また、その関数は構築済みのモデルを**キャプチャするだけでなく**、呼び出されるコンテキストで作成する必要があります。これは、TFF がデバイスで利用されるように設計されており、リソースが作られるタイミングを制御することで、キャプチャしてパッケージ化できる必要があるためです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f2dLONjFnE2E"
      },
      "outputs": [],
      "source": [
        "def create_original_fedavg_cnn_model(only_digits=True):\n",
        "  \"\"\"The CNN model used in https://arxiv.org/abs/1602.05629.\"\"\"\n",
        "  data_format = 'channels_last'\n",
        "\n",
        "  max_pool = functools.partial(\n",
        "      tf.keras.layers.MaxPooling2D,\n",
        "      pool_size=(2, 2),\n",
        "      padding='same',\n",
        "      data_format=data_format)\n",
        "  conv2d = functools.partial(\n",
        "      tf.keras.layers.Conv2D,\n",
        "      kernel_size=5,\n",
        "      padding='same',\n",
        "      data_format=data_format,\n",
        "      activation=tf.nn.relu)\n",
        "\n",
        "  model = tf.keras.models.Sequential([\n",
        "      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n",
        "      conv2d(filters=32),\n",
        "      max_pool(),\n",
        "      conv2d(filters=64),\n",
        "      max_pool(),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
        "      tf.keras.layers.Dense(10 if only_digits else 62),\n",
        "      tf.keras.layers.Softmax(),\n",
        "  ])\n",
        "\n",
        "  return model\n",
        "\n",
        "# Gets the type information of the input data. TFF is a strongly typed\n",
        "# functional programming framework, and needs type information about inputs to \n",
        "# the model.\n",
        "input_spec = emnist_train.create_tf_dataset_for_client(\n",
        "    emnist_train.client_ids[0]).element_spec\n",
        "\n",
        "def tff_model_fn():\n",
        "  keras_model = create_original_fedavg_cnn_model()\n",
        "  return tff.learning.from_keras_model(\n",
        "      keras_model=keras_model,\n",
        "      input_spec=input_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ipfUaPLEhYYj"
      },
      "source": [
        "## モデルのトレーニングとトレーニングメトリックの出力\n",
        "\n",
        "フェデレーテッドアベレージングアルゴリズムを作成し、定義済みのモデルを EMNIST データセットでトレーニングする準備が整いました。\n",
        "\n",
        "まず、[tff.learning.build_federated_averaging_process](https://www.tensorflow.org/federated/api_docs/python/tff/learning/build_federated_averaging_process) API を使用して、フェデレーテッドアベレージングアルゴリズムを構築する必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SAsGGkL9nHEl"
      },
      "outputs": [],
      "source": [
        "federated_averaging = tff.learning.build_federated_averaging_process(\n",
        "    model_fn=tff_model_fn,\n",
        "    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),\n",
        "    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mn1FAPQ32FcV"
      },
      "source": [
        "では、フェデレーテッドアベレージングアルゴリズムを実行しましょう。TFF の観点からフェデレーテッドアベレージングアルゴリズムを実行するには、次のようになります。\n",
        "\n",
        "1. アルゴリズムを初期化し、サーバーの初期状態を取得します。サーバーの状態には、アルゴリズムを実行するために必要な情報が含まれます。TFF は関数型であるため、この状態には、アルゴリズムが使用するオプティマイザの状態（慣性項）だけでなく、モデルパラメータ自体も含まれることを思い出してください。これらは引数として渡され、TFF 計算の結果として返されます。\n",
        "2. ラウンドごとにアルゴリズムを実行します。各ラウンドでは、新しいサーバーの状態が、データでモデルをトレーニングしている各クライアントの結果として返されます。通常、1 つのラウンドでは次のことが発生します。\n",
        "    1. サーバーはすべての参加クライアントにモデルをブロードキャストします。\n",
        "    2. 各クライアントは、モデルとそのデータに基づいて作業を実施します。\n",
        "    3. サーバーはすべてのモデルを集約し、新しいモデルを含むサーバーの状態を生成します。\n",
        "\n",
        "詳細については、[カスタムフェデレーテッドアルゴリズム、パート 2: フェデレーテッドアベレージングの実装](https://www.tensorflow.org/federated/tutorials/custom_federated_algorithms_2)チュートリアルをご覧ください。\n",
        "\n",
        "トレーニングメトリックは、トレーニング後に表示できるように、TensorBoard ディレクトリに書き込まれます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "t5n9fXsGOO6-"
      },
      "outputs": [],
      "source": [
        "#@title Load utility functions\n",
        "\n",
        "def format_size(size):\n",
        "  \"\"\"A helper function for creating a human-readable size.\"\"\"\n",
        "  size = float(size)\n",
        "  for unit in ['B','KiB','MiB','GiB']:\n",
        "    if size < 1024.0:\n",
        "      return \"{size:3.2f}{unit}\".format(size=size, unit=unit)\n",
        "    size /= 1024.0\n",
        "  return \"{size:.2f}{unit}\".format(size=size, unit='TiB')\n",
        "\n",
        "def set_sizing_environment():\n",
        "  \"\"\"Creates an environment that contains sizing information.\"\"\"\n",
        "  # Creates a sizing executor factory to output communication cost\n",
        "  # after the training finishes. Note that sizing executor only provides an\n",
        "  # estimate (not exact) of communication cost, and doesn't capture cases like\n",
        "  # compression of over-the-wire representations. However, it's perfect for\n",
        "  # demonstrating the effect of compression in this tutorial.\n",
        "  sizing_factory = tff.framework.sizing_executor_factory()\n",
        "\n",
        "  # TFF has a modular runtime you can configure yourself for various\n",
        "  # environments and purposes, and this example just shows how to configure one\n",
        "  # part of it to report the size of things.\n",
        "  context = tff.framework.ExecutionContext(executor_fn=sizing_factory)\n",
        "  tff.framework.set_default_context(context)\n",
        "\n",
        "  return sizing_factory"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jvH6qIgynI8S"
      },
      "outputs": [],
      "source": [
        "def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):\n",
        "  \"\"\"Trains the federated averaging process and output metrics.\"\"\"\n",
        "  # Create a environment to get communication cost.\n",
        "  environment = set_sizing_environment()\n",
        "\n",
        "  # Initialize the Federated Averaging algorithm to get the initial server state.\n",
        "  state = federated_averaging_process.initialize()\n",
        "\n",
        "  with summary_writer.as_default():\n",
        "    for round_num in range(num_rounds):\n",
        "      # Sample the clients parcitipated in this round.\n",
        "      sampled_clients = np.random.choice(\n",
        "          emnist_train.client_ids,\n",
        "          size=num_clients_per_round,\n",
        "          replace=False)\n",
        "      # Create a list of `tf.Dataset` instances from the data of sampled clients.\n",
        "      sampled_train_data = [\n",
        "          emnist_train.create_tf_dataset_for_client(client)\n",
        "          for client in sampled_clients\n",
        "      ]\n",
        "      # Round one round of the algorithm based on the server state and client data\n",
        "      # and output the new state and metrics.\n",
        "      state, metrics = federated_averaging_process.next(state, sampled_train_data)\n",
        "\n",
        "      # For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo\n",
        "      size_info = environment.get_size_info()\n",
        "      broadcasted_bits = size_info.broadcast_bits[-1]\n",
        "      aggregated_bits = size_info.aggregate_bits[-1]\n",
        "\n",
        "      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))\n",
        "\n",
        "      # Add metrics to Tensorboard.\n",
        "      for name, value in metrics['train']._asdict().items():\n",
        "          tf.summary.scalar(name, value, step=round_num)\n",
        "\n",
        "      # Add broadcasted and aggregated data size to Tensorboard.\n",
        "      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)\n",
        "      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)\n",
        "      summary_writer.flush()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xp3o3QcBlqY_"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "round  0, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.09433962404727936,loss=2.3181073665618896>>, broadcasted_bits=507.62MiB, aggregated_bits=507.62MiB\n",
            "round  1, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.0765027329325676,loss=2.3148586750030518>>, broadcasted_bits=1015.24MiB, aggregated_bits=1015.24MiB\n",
            "round  2, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.08872458338737488,loss=2.3089394569396973>>, broadcasted_bits=1.49GiB, aggregated_bits=1.49GiB\n",
            "round  3, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.10852713137865067,loss=2.304060220718384>>, broadcasted_bits=1.98GiB, aggregated_bits=1.98GiB\n",
            "round  4, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.10818713158369064,loss=2.3026843070983887>>, broadcasted_bits=2.48GiB, aggregated_bits=2.48GiB\n",
            "round  5, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.10454985499382019,loss=2.300365447998047>>, broadcasted_bits=2.97GiB, aggregated_bits=2.97GiB\n",
            "round  6, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12841254472732544,loss=2.29765248298645>>, broadcasted_bits=3.47GiB, aggregated_bits=3.47GiB\n",
            "round  7, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.14023210108280182,loss=2.2977216243743896>>, broadcasted_bits=3.97GiB, aggregated_bits=3.97GiB\n",
            "round  8, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.15060241520404816,loss=2.29490327835083>>, broadcasted_bits=4.46GiB, aggregated_bits=4.46GiB\n",
            "round  9, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.13088512420654297,loss=2.2942349910736084>>, broadcasted_bits=4.96GiB, aggregated_bits=4.96GiB\n"
          ]
        }
      ],
      "source": [
        "# Clean the log directory to avoid conflicts.\n",
        "!rm -R /tmp/logs/scalars/*\n",
        "\n",
        "# Set up the log directory and writer for Tensorboard.\n",
        "logdir = \"/tmp/logs/scalars/original/\"\n",
        "summary_writer = tf.summary.create_file_writer(logdir)\n",
        "\n",
        "train(federated_averaging_process=federated_averaging, num_rounds=10,\n",
        "      num_clients_per_round=10, summary_writer=summary_writer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zwdpTySt7pGQ"
      },
      "source": [
        "上記に示されるルートログディレクトリで TensorBoard を起動すると、トレーニングメトリックが表示されます。データの読み込みには数秒かかることがあります。Loss と Accuracy を除き、ブロードキャストされ集約されたデータの量も出力されます。ブロードキャストされたデータは、各クライアントにサーバーがプッシュしたテンソルで、集約データとは各クライアントがサーバーに返すテンソルを指します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EJ9XQiL-7e1i"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir /tmp/logs/scalars/ --port=0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rY5tWN_5ht6-"
      },
      "source": [
        "## カスタムブロードキャストと集約関数を構築する\n",
        "\n",
        "[tensor_encoding](http://jakubkonecny.com/files/tensor_encoding.pdf) API を使用して、ブロードキャストされたデータと集約データに対して非可逆圧縮アルゴリズムを使用する関数を実装しましょう。\n",
        "\n",
        "まず、2 つの関数を定義します。\n",
        "\n",
        "- `broadcast_encoder_fn`: サーバーのテンソルまたは変数をクライアント通信にエンコードする [te.core.SimpleEncoder](https://github.com/tensorflow/model-optimization/blob/ee53c9a9ae2e18ac1e443842b0b96229f0afb6d6/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py#L30) のインスタンスを作成します（ブロードキャストデータ）。\n",
        "- `mean_encoder_fn`: クライアントのテンソルまたは変数をサーバー通信にエンコードする [te.core.GatherEncoder](https://github.com/tensorflow/model-optimization/blob/ee53c9a9ae2e18ac1e443842b0b96229f0afb6d6/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py#L30) インスタンスを作成します（集約データ）。\n",
        "\n",
        "一度にモデル全体に圧縮メソッドを適用しないことに十分に注意してください。モデルの各変数を圧縮するかどうか、またはどのように圧縮するかは、個別に決定します。これは一般的に、バイアスなどの小さな変数は不正確性により敏感であり、比較的小さいことから、潜在的な通信の節約量も比較的小さくなるためです。そのため、デフォルトでは小さな変数を圧縮しません。この例では、10000 個を超える要素を持つ変数ごとに 8 ビット（256 バケット）の均一量子化を適用し、ほかの変数にのみ ID を適用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lkRHkZTTnKn2"
      },
      "outputs": [],
      "source": [
        "def broadcast_encoder_fn(value):\n",
        "  \"\"\"Function for building encoded broadcast.\"\"\"\n",
        "  spec = tf.TensorSpec(value.shape, value.dtype)\n",
        "  if value.shape.num_elements() > 10000:\n",
        "    return te.encoders.as_simple_encoder(\n",
        "        te.encoders.uniform_quantization(bits=8), spec)\n",
        "  else:\n",
        "    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)\n",
        "\n",
        "\n",
        "def mean_encoder_fn(value):\n",
        "  \"\"\"Function for building encoded mean.\"\"\"\n",
        "  spec = tf.TensorSpec(value.shape, value.dtype)\n",
        "  if value.shape.num_elements() > 10000:\n",
        "    return te.encoders.as_gather_encoder(\n",
        "        te.encoders.uniform_quantization(bits=8), spec)\n",
        "  else:\n",
        "    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "82iYUklQKP2e"
      },
      "source": [
        "TFF は、エンコーダ関数を `tff.learning.build_federated_averaging_process` API が消費できる形式に変換する API を提供しています。`tff.learning.framework.build_encoded_broadcast_from_model` と `tff.learning.framework.build_encoded_mean_from_model` を使用することで、`tff.learning.build_federated_averaging_process` の `broadcast_process` と `aggregation_process` 引数に渡して、非可逆圧縮アルゴリズムでフェデレーテッドアベレージングアルゴリズムを作成するための関数を 2 つ作成することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aqD61hqAGZiW"
      },
      "outputs": [],
      "source": [
        "encoded_broadcast_process = (\n",
        "    tff.learning.framework.build_encoded_broadcast_process_from_model(\n",
        "        tff_model_fn, broadcast_encoder_fn))\n",
        "encoded_mean_process = (\n",
        "    tff.learning.framework.build_encoded_mean_process_from_model(\n",
        "    tff_model_fn, mean_encoder_fn))\n",
        "\n",
        "federated_averaging_with_compression = tff.learning.build_federated_averaging_process(\n",
        "    tff_model_fn,\n",
        "    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),\n",
        "    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),\n",
        "    broadcast_process=encoded_broadcast_process,\n",
        "    aggregation_process=encoded_mean_process)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v3-ADI0hjTqH"
      },
      "source": [
        "## もう一度モデルをトレーニングする\n",
        "\n",
        "では、新しいフェデレーテッドアベレージングアルゴリズムを実行しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0KM_THYdn1yH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "round  0, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.08722109347581863,loss=2.3216357231140137>>, broadcasted_bits=146.46MiB, aggregated_bits=146.46MiB\n",
            "round  1, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.08379272371530533,loss=2.3108291625976562>>, broadcasted_bits=292.92MiB, aggregated_bits=292.92MiB\n",
            "round  2, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.08834951370954514,loss=2.3074147701263428>>, broadcasted_bits=439.38MiB, aggregated_bits=439.39MiB\n",
            "round  3, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.10467479377985,loss=2.305814027786255>>, broadcasted_bits=585.84MiB, aggregated_bits=585.85MiB\n",
            "round  4, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.09853658825159073,loss=2.3012874126434326>>, broadcasted_bits=732.30MiB, aggregated_bits=732.31MiB\n",
            "round  5, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.14904330670833588,loss=2.3005223274230957>>, broadcasted_bits=878.77MiB, aggregated_bits=878.77MiB\n",
            "round  6, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.13152804970741272,loss=2.2985599040985107>>, broadcasted_bits=1.00GiB, aggregated_bits=1.00GiB\n",
            "round  7, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12392637878656387,loss=2.297102451324463>>, broadcasted_bits=1.14GiB, aggregated_bits=1.14GiB\n",
            "round  8, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.13289350271224976,loss=2.2944107055664062>>, broadcasted_bits=1.29GiB, aggregated_bits=1.29GiB\n",
            "round  9, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12661737203598022,loss=2.2971296310424805>>, broadcasted_bits=1.43GiB, aggregated_bits=1.43GiB\n"
          ]
        }
      ],
      "source": [
        "logdir_for_compression = \"/tmp/logs/scalars/compression/\"\n",
        "summary_writer_for_compression = tf.summary.create_file_writer(\n",
        "    logdir_for_compression)\n",
        "\n",
        "train(federated_averaging_process=federated_averaging_with_compression, \n",
        "      num_rounds=10,\n",
        "      num_clients_per_round=10,\n",
        "      summary_writer=summary_writer_for_compression)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sE8Bnjel8TIA"
      },
      "source": [
        "もう一度 TensorBoard を起動して、2 つの実行のトレーニングメトリックを比較します。\n",
        "\n",
        "Tensorboard を見てわかるように、`broadcasted_bits` と `aggregated_bits` 図 の `orginial` と `compression` の曲線に大きな減少を確認できます。`loss` と `sparse_categorical_accuracy` 図では、この 2 つの曲線は非常に似通っていました。\n",
        "\n",
        "最後に、元のフェデレーテッドアベレージングアルゴリズムに似たパフォーマンスを達成できる圧縮アルゴリズムを実装しながら、通信コストを大幅に削減することができました。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K9M2_1re28ff"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir /tmp/logs/scalars/ --port=0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jaz9_9H7NUMW"
      },
      "source": [
        "## 演習\n",
        "\n",
        "カスタム圧縮アルゴリズムを実装してトレーニングループに適用するには、次の手順に従います。\n",
        "\n",
        "1. 新しい圧縮アルゴリズムを[ `EncodingStageInterface ` ](https://github.com/tensorflow/model-optimization/blob/ee53c9a9ae2e18ac1e443842b0b96229f0afb6d6/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/encoding_stage.py#L75)のサブクラスとして実装します。または、[この例](https://github.com/tensorflow/model-optimization/blob/ee53c9a9ae2e18ac1e443842b0b96229f0afb6d6/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/encoding_stage.py#L274)のようにより一般的なバリアント<a><code>AdaptiveEncodingStageInterface</code></a>として実装します。\n",
        "2. 新しい[ `Encoder` ](https://github.com/tensorflow/model-optimization/blob/ee53c9a9ae2e18ac1e443842b0b96229f0afb6d6/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/core_encoder.py#L38)を作成し、[モデルブロードキャスト](https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/research/compression/run_experiment.py#L95)または[モデル更新の平均化](https://github.com/tensorflow/federated/blob/e67590f284b487c6b889c070a96c35b8e0341e3b/tensorflow_federated/python/research/compression/run_experiment.py#L95)に特化させます。\n",
        "3. これらのオブジェクトを使用して、[トレーニング計算全体](https://github.com/tensorflow/federated/blob/e67590f284b487c6b889c070a96c35b8e0341e3b/tensorflow_federated/python/research/compression/run_experiment.py#L204)を構築します。\n",
        "\n",
        "潜在的に価値の高いオープンリサーチの問いには、非均一量子化、ハフマンコーディングなどの可逆圧縮、および以前のトレーニングラウンドからの情報に基づいて圧縮を適応させるメカニズムが含まれます。\n",
        "\n",
        "次は、推奨される読み物です。\n",
        "\n",
        "- [Expanding the Reach of Federated Learning by Reducing Client Resource Requirements](https://research.google/pubs/pub47774/)\n",
        "- [Federated Learning: Strategies for Improving Communication Efficiency](https://research.google/pubs/pub45648/)\n",
        "- <a>Advanced and Open Problems in Federated Learning</a> の<em>セクション 3.5 Communication and Compression</em>"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "tff_for_federated_learning_research_compression.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
